﻿using Microsoft.Extensions.Primitives;
using System.Net.WebSockets;
using System.Threading;

namespace SimpleX
{
    /// <summary>
    /// Middleware that accepts WebSocket connections on <c>/ws</c>, registers each connection in
    /// <c>WebsocketClientCollection</c> under its client id, and relays text messages between clients.
    /// Requests to any other path are passed to the next middleware.
    /// </summary>
    public class WebsocketHandlerMiddleware
    {
        // Initial size of the receive buffer; it grows when a single message is larger than this.
        private const int ReceiveBufferSize = 1024 * 4;

        // Largest message accepted from a client. Anything bigger closes the connection with
        // MessageTooBig, so a client cannot make the server buffer an unbounded amount of data.
        private const int MaxMessageSize = 1024 * 1024;

        private readonly RequestDelegate _next;
        private readonly ILogger<WebsocketHandlerMiddleware> _logger;

        /// <summary>Creates the middleware.</summary>
        /// <param name="next">Next middleware in the pipeline; invoked for every request whose path is not <c>/ws</c>.</param>
        /// <param name="logger">Logger for connection lifecycle events and relay failures.</param>
        public WebsocketHandlerMiddleware(
            RequestDelegate next,
            ILogger<WebsocketHandlerMiddleware> logger
            )
        {
            _next = next;
            _logger = logger;
        }

        /*
         Example nginx reverse-proxy configuration for this endpoint. The WebSocket handshake needs
         HTTP/1.1 with the Upgrade/Connection headers forwarded. proxy_read_timeout makes nginx close
         the connection if this server sends nothing for that long.
         $connection_upgrade is not built into nginx; define it in the http block, e.g.
             map $http_upgrade $connection_upgrade { default upgrade; '' close; }

            location /ws {
            proxy_pass http://10.171.11.162:8000/ws$is_args$args;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            proxy_read_timeout 300s;
            proxy_send_timeout 300s;
            proxy_connect_timeout 300s;
                }
         */

        /// <summary>
        /// Handles <c>/ws</c> requests; every other path goes to the next middleware.
        /// A <c>/ws</c> request gets 404 when it is not a WebSocket upgrade or carries no client id
        /// (taken from the <c>ClientId</c> header, falling back to the <c>ClientId</c> query parameter).
        /// Otherwise the WebSocket is accepted and served until it closes or the request is aborted.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            if (context.Request.Path != "/ws")
            {
                await _next(context);
                return;
            }

            if (!context.WebSockets.IsWebSocketRequest)
            {
                context.Response.StatusCode = 404;
                return;
            }

            string clientId = GetClientId(context.Request);
            if (string.IsNullOrWhiteSpace(clientId))
            {
                context.Response.StatusCode = 404;
                return;
            }

            WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

            var wsClient = new WebsocketClient
            {
                ClientId = clientId,
                WebSocket = webSocket
            };
            try
            {
                await Handle(wsClient, context.RequestAborted);
            }
            catch (Exception ex)
            {
                // The response has already been upgraded to a WebSocket, so writing to it would
                // throw; logging is the only thing left to do.
                _logger.LogError(ex, "Websocket client {ClientId} failed.", clientId);
            }
        }

        // Reads the client id from the "ClientId" request header, falling back to the first
        // "ClientId" query-string value when the header is missing or blank. May return null/blank.
        private static string GetClientId(HttpRequest request)
        {
            string clientId = request.Headers.TryGetValue("ClientId", out StringValues headerValues)
                ? headerValues.FirstOrDefault()
                : null;

            if (string.IsNullOrWhiteSpace(clientId) && request.QueryString.HasValue)
            {
                clientId = request.Query["ClientId"].FirstOrDefault();
            }

            return clientId;
        }

        /// <summary>
        /// Registers <paramref name="client"/> (replacing any registered client with the same id), then
        /// receives messages until the socket closes, and relays each complete text message.
        /// The client is always removed from <c>WebsocketClientCollection</c> on exit.
        /// </summary>
        /// <param name="client">The accepted connection and its client id.</param>
        /// <param name="cancellationToken">Stops receiving when signalled (the request-aborted token from <see cref="Invoke"/>).</param>
        private async Task Handle(WebsocketClient client, CancellationToken cancellationToken = default)
        {
            WebsocketClientCollection.RemoveSameId(client);
            WebsocketClientCollection.Add(client);
            _logger.LogInformation("Websocket client added. {ClientId}", client.ClientId);

            WebSocket socket = client.WebSocket;
            byte[] buffer = new byte[ReceiveBufferSize];
            try
            {
                while (socket.State == WebSocketState.Open)
                {
                    // A single message may arrive as several frames; collect them until EndOfMessage.
                    int count = 0;
                    WebSocketReceiveResult result;
                    do
                    {
                        if (count == buffer.Length)
                        {
                            if (buffer.Length >= MaxMessageSize)
                            {
                                _logger.LogWarning("Websocket client {ClientId} sent a message larger than {MaxMessageSize} bytes; closing.", client.ClientId, MaxMessageSize);
                                await socket.CloseOutputAsync(WebSocketCloseStatus.MessageTooBig, "Message too big", cancellationToken);
                                return;
                            }
                            Array.Resize(ref buffer, Math.Min(buffer.Length * 2, MaxMessageSize));
                        }

                        result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer, count, buffer.Length - count), cancellationToken);
                        count += result.Count;
                    }
                    while (!result.EndOfMessage);

                    if (result.MessageType == WebSocketMessageType.Close)
                    {
                        // Complete the close handshake the client started.
                        await socket.CloseAsync(result.CloseStatus ?? WebSocketCloseStatus.NormalClosure, result.CloseStatusDescription, cancellationToken);
                        break;
                    }

                    if (result.MessageType == WebSocketMessageType.Text)
                    {
                        // Decode only the bytes actually received, not the whole buffer.
                        await ForwardAsync(client.ClientId, Encoding.UTF8.GetString(buffer, 0, count));
                    }
                }
            }
            catch (WebSocketException ex)
            {
                // The peer went away without a proper close handshake; nothing to do but clean up.
                _logger.LogInformation(ex, "Websocket client {ClientId} disconnected unexpectedly.", client.ClientId);
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // The HTTP request was aborted; fall through to cleanup.
            }
            finally
            {
                WebsocketClientCollection.Remove(client);
                _logger.LogInformation("Websocket client removed. {ClientId}", client.ClientId);
            }
        }

        /// <summary>
        /// Parses one text message as a <c>WebsocketMessage&lt;string&gt;</c> whose <c>Action</c> has the form
        /// <c>action_targetClientId</c>. It forwards a new message to the target client, with <c>Action</c>
        /// set to the first '_'-separated segment and <c>Data</c> unchanged. Further '_' segments are ignored.
        /// Messages that cannot be parsed or delivered are logged and dropped; nothing is thrown.
        /// </summary>
        /// <param name="senderId">Client id of the sender, used only for logging.</param>
        /// <param name="json">The complete text payload received from the sender.</param>
        private async Task ForwardAsync(string senderId, string json)
        {
            try
            {
                var message = JsonConvert.DeserializeObject<WebsocketMessage<string>>(json);
                var actions = message?.Action?.Split('_');
                if (actions == null || actions.Length < 2)
                {
                    return;
                }

                await WebsocketClientCollection.Send(actions[1], new WebsocketMessage<string>
                {
                    Action = actions[0],
                    Data = message.Data,
                });
            }
            catch (Exception ex)
            {
                // Relaying is best-effort: a bad or undeliverable message must not drop the sender's connection.
                _logger.LogWarning(ex, "Failed to relay message from websocket client {ClientId}.", senderId);
            }
        }
    }
}